#!/usr/bin/env python3

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mmi2grpc gRPC compiler.

A protoc plugin: reads a serialized CodeGeneratorRequest from stdin and writes
a CodeGeneratorResponse to stdout. For each file to generate, the response
contains a ``<name>_grpc.py`` module with client stub classes, servicer base
classes and ``add_<Service>Servicer_to_server`` helpers.
"""
| |
import sys

from google.protobuf.compiler.plugin_pb2 import (
    CodeGeneratorRequest,
    CodeGeneratorResponse,
)
| |
| |
def eprint(*args, **kwargs):
    """Print ``args`` to standard error.

    Keyword arguments are forwarded to ``print``. Passing ``file`` raises
    TypeError because the destination is always ``sys.stderr``.
    """
    print(*args, **kwargs, file=sys.stderr)
| |
| |
# protoc hands the plugin a serialized CodeGeneratorRequest on stdin. It is
# parsed at module level because import_type() and the generation loop at the
# bottom of this file both read it as a global.
request = CodeGeneratorRequest()
request.ParseFromString(sys.stdin.buffer.read())
| |
| |
def has_type(proto_file, type_name):
    """Return True if ``proto_file`` declares a top-level message ``type_name``.

    Only ``proto_file.message_type`` is searched, so nested messages are not
    found.
    """
    # Test the name comparison itself instead of the truthiness of the matched
    # descriptor objects, which is what any(filter(...)) evaluated.
    return any(message.name == type_name for message in proto_file.message_type)
| |
| |
def import_type(imports, type):
    """Record the import needed for a message type and return its reference.

    Args:
        imports: set of import statements; the statement that brings the
            type's ``*_pb2`` module into scope is added to it.
        type: fully-qualified type name with a leading '.', as found in
            ``method.input_type`` / ``method.output_type``
            (e.g. ``'.pkg.Message'``).

    Returns:
        The expression naming the message class in generated code, e.g.
        ``'dir_dot_file__pb2.Message'`` for ``dir/file.proto``.

    Raises:
        ValueError: if no proto file in ``request`` with the type's package
            declares it as a top-level message (see ``has_type``). Nested
            types such as ``'.pkg.Outer.Inner'`` are therefore not resolved.
    """
    # '.pkg.sub.Msg' -> package 'pkg.sub', name 'Msg'; package is '' for '.Msg'.
    package, _, type_name = type[1:].rpartition('.')
    proto_file = next(
        (f for f in request.proto_file
         if f.package == package and has_type(f, type_name)),
        None)
    if proto_file is None:
        raise ValueError(
            f'cannot resolve message type {type!r}: no proto file in the '
            f'request declares it as a top-level message')

    # Strip only the trailing '.proto', not every occurrence in the path.
    name = proto_file.name
    if name.endswith('.proto'):
        name = name[:-len('.proto')]
    python_path = name.replace('/', '.')
    as_name = python_path.replace('.', '_dot_') + '__pb2'

    module_path, _, module_name = python_path.rpartition('.')
    module_name += '_pb2'
    if module_path:
        imports.add(f'from {module_path} import {module_name} as {as_name}')
    else:
        # Proto file at the root of the include path: no parent package, so
        # ``from ... import`` has nothing to import from.
        imports.add(f'import {module_name} as {as_name}')
    return f'{as_name}.{type_name}'
| |
| |
def generate_service_method(imports, file, service, method):
    """Return the source lines of one client-stub method of ``service``.

    Unary-request methods accept the request message's fields as keyword
    arguments (plus ``wait_for_ready``) and build the message themselves.
    Client-streaming methods accept an iterator of request messages and pass
    any extra keyword arguments to the call. Lines are relative to the
    method's own indentation; the caller re-indents them when joining.

    Args:
        imports: set of import statements, extended with the ``*_pb2``
            modules providing the request/response classes.
        file: the FileDescriptorProto declaring ``service``.
        service: the ServiceDescriptorProto the method belongs to.
        method: the MethodDescriptorProto to generate.

    Returns:
        list of str, one per generated line.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    # The gRPC method path is '/<package>.<Service>/<Method>', or
    # '/<Service>/<Method>' when the proto file declares no package.
    if file.package:
        full_service = f'{file.package}.{service.name}'
    else:
        full_service = service.name

    if input_mode == 'stream':
        # Callers pass an iterator of request messages plus call options.
        params = 'iterator, **kwargs'
        call_args = 'iterator, **kwargs'
    else:
        # Keyword arguments become the fields of the request message.
        params = 'wait_for_ready=None, **kwargs'
        call_args = f'{input_type}(**kwargs), wait_for_ready=wait_for_ready'

    return (
        f'def {method.name}(self, {params}):\n'
        f'    return self.channel.{input_mode}_{output_mode}(\n'
        f"        '/{full_service}/{method.name}',\n"
        f'        request_serializer={input_type}.SerializeToString,\n'
        f'        response_deserializer={output_type}.FromString\n'
        f'    )({call_args})'
    ).split('\n')
| |
| |
def generate_service(imports, file, service):
    """Return the source lines of the client stub class for ``service``.

    The class stores the ``channel`` passed to its constructor and has one
    method per RPC (see ``generate_service_method``).

    Args:
        imports: set of import statements, extended as message types are
            referenced.
        file: the FileDescriptorProto declaring ``service``.
        service: the ServiceDescriptorProto to generate.

    Returns:
        list of str, one per generated line.
    """
    # Methods sit one level (4 spaces) inside the class; each method's own
    # lines are re-indented by that level, and methods are blank-line separated.
    methods = '\n\n    '.join(
        '\n    '.join(generate_service_method(imports, file, service, method))
        for method in service.method
    )
    return (
        f'class {service.name}:\n'
        f'    def __init__(self, channel):\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')
| |
| |
def generate_servicer_method(method):
    """Return the source lines of a servicer method that reports UNIMPLEMENTED.

    The generated method sets ``grpc.StatusCode.UNIMPLEMENTED`` on the context
    and raises NotImplementedError; subclasses are expected to override it.
    Lines are relative to the method's own indentation.

    Args:
        method: the MethodDescriptorProto (uses ``name`` and
            ``client_streaming``).

    Returns:
        list of str, one per generated line.
    """
    # Client-streaming handlers receive an iterator of requests instead of a
    # single request; otherwise the stub body is identical.
    request_param = 'request_iterator' if method.client_streaming else 'request'
    return (
        f'def {method.name}(self, {request_param}, context):\n'
        f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n'
        f'    context.set_details("Method not implemented!")\n'
        f'    raise NotImplementedError("Method not implemented!")'
    ).split('\n')
| |
| |
def generate_servicer(service):
    """Return the source lines of the ``<Service>Servicer`` base class.

    Each RPC gets an UNIMPLEMENTED stub (see ``generate_servicer_method``).
    A service without RPCs gets a ``pass`` body so the generated module
    still compiles.

    Args:
        service: the ServiceDescriptorProto to generate.

    Returns:
        list of str, one per generated line.
    """
    if service.method:
        methods = '\n\n    '.join(
            '\n    '.join(generate_servicer_method(method))
            for method in service.method
        )
    else:
        # A class body cannot be empty: without this the generated file
        # raises a SyntaxError when imported.
        methods = 'pass'
    return (
        f'class {service.name}Servicer:\n'
        f'\n'
        f'    {methods}\n'
    ).split('\n')
| |
| |
def generate_rpc_method_handler(imports, method):
    """Return the source lines of one ``rpc_method_handlers`` dict entry.

    The entry maps the method name to a
    ``grpc.<in>_<out>_rpc_method_handler`` wrapping ``servicer.<method>``.
    Lines are relative to the entry's own indentation; the caller re-indents
    them when joining.

    Args:
        imports: set of import statements, extended with the ``*_pb2``
            modules providing the request/response classes.
        method: the MethodDescriptorProto to generate.

    Returns:
        list of str. The last element is '' because the entry text ends with
        a newline.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type = import_type(imports, method.input_type)
    output_type = import_type(imports, method.output_type)

    # Arguments sit 8 columns in; the caller adds 4 more when joining, which
    # nests them one level inside the entry. The closing paren lines up with
    # the entry's key.
    return (
        f"'{method.name}': grpc.{input_mode}_{output_mode}_rpc_method_handler(\n"
        f'        servicer.{method.name},\n'
        f'        request_deserializer={input_type}.FromString,\n'
        f'        response_serializer={output_type}.SerializeToString,\n'
        f'    ),\n'
    ).split('\n')
| |
| |
def generate_add_servicer_to_server_method(imports, file, service):
    """Return the source lines of ``add_<Service>Servicer_to_server``.

    The generated function builds a dict of RPC method handlers bound to
    ``servicer``, wraps it in a generic handler named after the service, and
    registers it on ``server``.

    Args:
        imports: set of import statements, extended as message types are
            referenced.
        file: the FileDescriptorProto declaring ``service``.
        service: the ServiceDescriptorProto to generate.

    Returns:
        list of str, one per generated line.
    """
    # Each entry's continuation lines get +4 columns. The trailing empty
    # element of each entry leaves a 4-space indent that the '    ' separator
    # tops up to the 8 columns the first entry gets from the template.
    method_handlers = '    '.join(
        '\n    '.join(generate_rpc_method_handler(imports, method))
        for method in service.method
    )
    # gRPC names a service '<package>.<Service>', or just '<Service>' when the
    # proto file declares no package.
    if file.package:
        full_service = f'{file.package}.{service.name}'
    else:
        full_service = service.name
    return (
        f'def add_{service.name}Servicer_to_server(servicer, server):\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler(\n'
        f"        '{full_service}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,))'
    ).split('\n')
| |
| |
# One generated '<name>_grpc.py' per .proto file protoc asked us to generate.
files = []

for file_name in request.file_to_generate:
    file = next(x for x in request.proto_file if x.name == file_name)

    # '*_pb2' import statements needed by the generated code; the generators
    # below add to it through import_type().
    imports = set()

    # Flatten the per-service line lists (linear, unlike sum(..., [])).
    services = '\n'.join(
        line
        for service in file.service
        for line in generate_service(imports, file, service))

    servicers = '\n'.join(
        line
        for service in file.service
        for line in generate_servicer(service))

    add_servicer_methods = '\n'.join(
        line
        for service in file.service
        for line in generate_add_servicer_to_server_method(
            imports, file, service))

    # Sorted so the output is byte-identical across runs: iteration order of
    # a set of str depends on the per-process hash seed.
    header = '\n'.join(['import grpc', *sorted(imports)])

    # Replace only the trailing '.proto' suffix, not every occurrence of
    # '.proto' in the path.
    if file_name.endswith('.proto'):
        base_name = file_name[:-len('.proto')]
    else:
        base_name = file_name

    files.append(CodeGeneratorResponse.File(
        name=base_name + '_grpc.py',
        content='\n\n'.join(
            [header, services, servicers, add_servicer_methods]) + '\n',
    ))

response = CodeGeneratorResponse(file=files)

# protoc reads the serialized response from the plugin's stdout.
sys.stdout.buffer.write(response.SerializeToString())