| #!/usr/bin/env python3 |
| |
| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Custom mmi2grpc gRPC compiler.""" |
| |
| import os |
| import sys |
| |
| from typing import Dict, List, Optional, Set, Tuple, Union |
| |
| from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse |
| from google.protobuf.descriptor import ( |
| FieldDescriptor |
| ) |
| from google.protobuf.descriptor_pb2 import ( |
| FileDescriptorProto, |
| EnumDescriptorProto, |
| DescriptorProto, |
| ServiceDescriptorProto, |
| MethodDescriptorProto, |
| FieldDescriptorProto, |
| ) |
| |
# Decode the serialized CodeGeneratorRequest handed to this plugin on stdin.
# This happens at import time; find_type() and the generation loop at the
# bottom of the file both read the descriptors from this module-level value.
_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())
| |
| |
def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]:
    """Look up a top-level enum or message declared in `proto_file`.

    Enums are searched before messages; nested declarations are not visited.

    Returns:
        The first descriptor whose name equals `type_name`, or None when the
        file declares no such top-level type.
    """
    groups = (proto_file.enum_type, proto_file.message_type)
    return next((decl for group in groups for decl in group if decl.name == type_name), None)
| |
| |
def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]:
    """Find the top-level type `package.type_name` among the request's files.

    Returns:
        `(file, descriptor)` for the first file of the protoc request whose
        package matches and which declares the type at top level.

    Raises:
        Exception: if no file in the request declares the type.
    """
    for proto_file in _REQUEST.proto_file:
        if proto_file.package != package:
            continue
        found = find_type_in_file(proto_file, type_name)
        if found:
            return proto_file, found
    raise Exception(f'Type {package}.{type_name} not found')
| |
| |
def add_import(imports: List[str], import_str: str) -> None:
    """Append `import_str` to `imports` unless it is empty or already listed.

    Several callers pass '' to mean "no import needed" (e.g. the default
    import reported for non-enum fields), so empty strings are ignored rather
    than emitted as a blank line in the generated import block.
    """
    if import_str and import_str not in imports:
        imports.append(import_str)
| |
| |
def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]:
    """Resolve a fully-qualified proto type name to a Python expression.

    Args:
        imports: import lines of the module being generated; the import of
            the `_pb2` module declaring `type` is appended here.
        type: fully-qualified type name with a leading '.', e.g. `.pkg.Msg`.
        local: the file being generated, or None. A type declared in `local`
            is referenced by its bare name and nothing is imported.

    Returns:
        `(python_type, descriptor, default_import)`. `default_import` is an
        import line bringing the first value of an enum type into scope; it
        is '' for messages and for local types.

    Raises:
        Exception: (from find_type) if the type is not in the request.
    """
    dot = type.rindex('.')
    package, type_name = type[1:dot], type[dot + 1:]
    proto_file, desc = find_type(package, type_name)
    if proto_file == local:
        return type_name, desc, ''

    # 'dir/sub/foo.proto' -> parent module 'dir.sub', module name 'foo_pb2'.
    python_path = proto_file.name.replace('.proto', '').replace('/', '.')
    module_path, _, stem = python_path.rpartition('.')
    module_name = stem + '_pb2'
    if module_path:
        add_import(imports, f'from {module_path} import {module_name}')
        qualified_module = f'{module_path}.{module_name}'
    else:
        # Proto at the root of the include path: there is no parent package,
        # so the generated module is imported directly.
        add_import(imports, f'import {module_name}')
        qualified_module = module_name

    dft_import = ''
    if isinstance(desc, EnumDescriptorProto):
        dft_import = f'from {qualified_module} import {desc.value[0].name}'
    return f'{module_name}.{type_name}', desc, dft_import
| |
| |
def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]:
    """Compute the Python type hint and default expression of a field.

    Args:
        imports: import lines of the module being generated; imports the
            returned hint relies on are appended here.
        parent: message declaring `field`; its nested types are searched for
            map entries.
        field: field to describe.
        local: file being generated, forwarded to `import_type`.

    Returns:
        `(type_hint, default_expr, default_import)`, where `default_import`
        is an import line needed to evaluate an enum default, else ''.

    Raises:
        Exception: for field types this generator does not handle.
    """
    int_kinds = (
        FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_UINT64,
        FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_FIXED64,
        FieldDescriptor.TYPE_FIXED32, FieldDescriptor.TYPE_UINT32,
        FieldDescriptor.TYPE_SFIXED32, FieldDescriptor.TYPE_SFIXED64,
        FieldDescriptor.TYPE_SINT32, FieldDescriptor.TYPE_SINT64,
    )
    # Scalar field kind -> (type hint, default expression).
    scalars = {
        FieldDescriptor.TYPE_BYTES: ('bytes', "b''"),
        FieldDescriptor.TYPE_STRING: ('str', "''"),
        FieldDescriptor.TYPE_BOOL: ('bool', 'False'),
        FieldDescriptor.TYPE_FLOAT: ('float', '0.0'),
        FieldDescriptor.TYPE_DOUBLE: ('float', '0.0'),
    }
    scalars.update((kind, ('int', '0')) for kind in int_kinds)

    extra_import = ''
    if field.type in scalars:
        hint, default = scalars[field.type]
    elif field.type in (FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE):
        pieces = field.type_name.split(f'.{parent.name}.', 2)
        if len(pieces) == 2:
            # Type declared inside `parent`: the only nested types handled
            # here are map entries (`options.map_entry`), rendered as Dict.
            entry = next((n for n in parent.nested_type if n.name == pieces[1]), None)
            if entry is not None:
                assert entry.options.map_entry
                assert field.label == FieldDescriptor.LABEL_REPEATED
                key_hint, _, _ = collect_type(imports, entry, entry.field[0], local)
                value_hint, _, _ = collect_type(imports, entry, entry.field[1], local)
                add_import(imports, 'from typing import Dict')
                return f'Dict[{key_hint}, {value_hint}]', '{}', ''
        hint, desc, enum_import = import_type(imports, field.type_name, local)
        if isinstance(desc, EnumDescriptorProto):
            # Enums default to their first declared value.
            extra_import = enum_import
            default = desc.value[0].name
        else:
            default = f'{hint}()'
    else:
        raise Exception(f'TODO: {field}')

    if field.label == FieldDescriptor.LABEL_REPEATED:
        add_import(imports, 'from typing import List')
        hint, default = f'List[{hint}]', '[]'

    return hint, default, extra_import
| |
| |
def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]:
    """Describe one field of `message` for code generation.

    Returns:
        `(oneof_index, name, type_hint, default_expr, default_import)`, where
        `oneof_index` indexes `message.oneof_decl` for fields that belong to
        a oneof and is None otherwise.
    """
    type_hint, default, default_import = collect_type(imports, message, field, local)
    # `oneof_index` is an optional field of FieldDescriptorProto, so its
    # presence is queried with HasField. Searching the text rendering of the
    # descriptor would also match, e.g., a field *named* "oneof_index".
    oneof_index = field.oneof_index if field.HasField('oneof_index') else None
    return oneof_index, field.name, type_hint, default, default_import
| |
| |
def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[
    List[Tuple[str, str, str]],
    Dict[str, List[Tuple[str, str]]],
]:
    """Collect the constructor fields and oneof groups of `message`.

    Returns:
        `(fields, oneof)`. `fields` holds `(name, type_hint, default_expr)`
        for every field: regular fields first, then oneof members typed as
        `Optional[...]` defaulting to `None`. `oneof` maps each oneof name to
        its `(name, type_hint)` members in declaration order.
    """
    fields: List[Tuple[str, str, str]] = []
    oneof: Dict[str, List[Tuple[str, str]]] = {}

    for field in message.field:
        idx, name, type_hint, default, default_import = collect_field(imports, message, field, local)
        if idx is not None:
            oneof.setdefault(message.oneof_decl[idx].name, []).append((name, type_hint))
        else:
            # Only enum defaults need an import; other fields report ''.
            if default_import:
                add_import(imports, default_import)
            fields.append((name, type_hint, default))

    for members in oneof.values():
        for name, type_hint in members:
            add_import(imports, 'from typing import Optional')
            fields.append((name, f'Optional[{type_hint}]', 'None'))

    return fields, oneof
| |
| |
def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Produce the `.pyi` stub lines for a top-level enum.

    Also queues, in `res`, a `module_scope` insertion into the matching
    `_pb2.py` that defines `class <EnumName>: ...`.
    NOTE(review): that insertion presumably rebinds the enum name defined by
    the `_pb2` module — confirm this is intended.

    Returns:
        Stub lines: an `(int, EnumTypeWrapper)` class followed by one
        annotated module-level constant per enum value, then a blank line.
    """
    target = file.name.replace('.proto', '_pb2.py')
    res.append(CodeGeneratorResponse.File(
        name=target,
        insertion_point='module_scope',
        content=f'class {enum.name}: ...\n\n',
    ))
    add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper')

    stub = [f'class {enum.name}(int, EnumTypeWrapper):', '    pass', '']
    stub.extend(f'{value.name}: {enum.name}' for value in enum.value)
    stub.append('')
    return stub
| |
| |
def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Produce `.pyi` stub lines for a top-level message and its oneofs.

    For each oneof `o`, the stub declares an `o` property, `o_variant()` and
    `o_asdict()`, plus a `TypedDict` describing the latter. A `module_scope`
    insertion is queued in `res` so that the matching `_pb2.py` implements
    them at runtime on top of `WhichOneof`.

    Returns:
        The message class stub lines followed by its oneof TypedDicts.
    """
    pb2_name = file.name.replace('.proto', '_pb2.py')
    nested_message_lines: List[str] = []
    message_lines: List[str] = [f'class {message.name}(Message):']

    add_import(imports, 'from google.protobuf.message import Message')
    fields, oneof = collect_message(imports, message, file)

    message_lines.extend(f'    {name}: {type_hint}' for name, type_hint, _ in fields)

    args = ', '.join(f'{name}: {type_hint} = {default}' for name, type_hint, default in fields)
    if args:
        args = ', ' + args
    message_lines.extend([
        '',
        f'    def __init__(self{args}) -> None: ...',
        '',
    ])

    for oneof_name, oneof_fields in oneof.items():
        literals = ', '.join(f"Literal['{name}']" for name, _ in oneof_fields)
        # Deduplicate member types but keep declaration order. Iterating a
        # set of strings would make the emitted Union order vary between runs,
        # because string hashing is randomized per process.
        member_types = list(dict.fromkeys(type_hint for _, type_hint in oneof_fields))
        if len(member_types) == 1:
            oneof_type = f'Optional[{member_types[0]}]'
        else:
            oneof_type = 'Union[' + ', '.join(member_types + ['None']) + ']'

        dict_name = f'{message.name}_{oneof_name}_dict'
        nested_message_lines.extend([
            f'class {dict_name}(TypedDict, total=False):',
            '\n'.join(f'    {name}: {type_hint}' for name, type_hint in oneof_fields),
            '',
        ])

        add_import(imports, 'from typing import Union')
        add_import(imports, 'from typing_extensions import TypedDict')
        add_import(imports, 'from typing_extensions import Literal')
        # One list element per emitted line: a missing comma here would
        # silently concatenate adjacent literals and drop the blank lines.
        message_lines.extend([
            '    @property',
            f'    def {oneof_name}(self) -> {oneof_type}: ...',
            '',
            f'    def {oneof_name}_variant(self) -> Union[{literals}, None]: ...',
            '',
            f'    def {oneof_name}_asdict(self) -> {dict_name}: ...',
            '',
        ])

        # Per-member dispatch lines for the runtime helpers, already indented
        # for the function bodies they are spliced into below.
        return_variant = '\n    '.join(
            f"if variant == '{name}': return unwrap(self.{name})" for name, _ in oneof_fields)
        return_asdict = '\n    '.join(
            f"if variant == '{name}': return {{'{name}': unwrap(self.{name})}} # type: ignore" for name, _ in oneof_fields)
        if return_variant:
            return_variant += '\n    '
        if return_asdict:
            return_asdict += '\n    '

        res.append(CodeGeneratorResponse.File(
            name=pb2_name,
            insertion_point='module_scope',
            content=f"""
def _{message.name}_{oneof_name}(self: {message.name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return None
    {return_variant}raise Exception('Field `{oneof_name}` not found.')

def _{message.name}_{oneof_name}_variant(self: {message.name}):
    return self.WhichOneof('{oneof_name}') # type: ignore

def _{message.name}_{oneof_name}_asdict(self: {message.name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return {{}}
    {return_asdict}raise Exception('Field `{oneof_name}` not found.')

setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name}))
setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant)
setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict)
""",
        ))

    return message_lines + nested_message_lines
| |
| |
def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Produce the client stub method for one RPC of `service`.

    The generated signature depends on the streaming mode:
      * client and server streaming: `(self, timeout)`, returning a
        `StreamStream` that pairs a `Sender` of requests with the responses;
      * client streaming only: `(self, iterator, timeout)`;
      * unary request: the request message fields as keyword arguments,
        followed by `wait_for_ready` and `timeout`.
    `timeout` is forwarded to the underlying channel call in every mode.

    Args:
        sync: target `grpc.Channel` when True, `grpc.aio.Channel` otherwise
            (the async helpers from `._utils` are imported under the sync
            names).

    Returns:
        Source lines of the method, indented relative to its `def`.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'
    rpc_path = f'/{file.package}.{service.name}/{method.name}'

    # The `<module>_pb2.<Message>` expression serves both as type hint and as
    # the owner of the (de)serializers, so one lookup per type suffices.
    input_type, input_msg, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    if output_mode == 'stream':
        if input_mode == 'stream':
            output_type_hint = f'StreamStream[{input_type}, {output_type}]'
            if sync:
                add_import(imports, 'from ._utils import Sender')
                add_import(imports, 'from ._utils import Stream')
                add_import(imports, 'from ._utils import StreamStream')
            else:
                add_import(imports, 'from ._utils import AioSender as Sender')
                add_import(imports, 'from ._utils import AioStream as Stream')
                add_import(imports, 'from ._utils import AioStreamStream as StreamStream')
        else:
            output_type_hint = f'Stream[{output_type}]'
            add_import(imports, 'from ._utils import Stream' if sync else 'from ._utils import AioStream as Stream')
    else:
        output_type_hint = output_type if sync else f'Awaitable[{output_type}]'
        if not sync:
            add_import(imports, 'from typing import Awaitable')

    add_import(imports, 'from typing import Optional')
    call_args = [
        f"        '{rpc_path}',",
        f'        request_serializer={input_type}.SerializeToString, # type: ignore',
        f'        response_deserializer={output_type}.FromString # type: ignore',
    ]

    if input_mode == 'stream' and output_mode == 'stream':
        return [
            f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:',
            f'    tx: Sender[{input_type}] = Sender()',
            f'    rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}( # type: ignore',
            *call_args,
            '    )(tx, timeout=timeout)',
            '    return StreamStream(tx, rx)',
        ]

    if input_mode == 'stream':
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        return [
            f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:',
            f'    return self.channel.{input_mode}_{output_mode}( # type: ignore',
            *call_args,
            '    )(iterator, timeout=timeout)',
        ]

    # Unary request: expose the request fields as keyword arguments and
    # build the message inside the stub.
    assert isinstance(input_msg, DescriptorProto)
    input_fields, _ = collect_message(imports, input_msg, None)
    args = ''.join(f', {name}: {type_hint} = {default}' for name, type_hint, default in input_fields)
    args_name = ', '.join(f'{name}={name}' for name, _, _ in input_fields)
    return [
        f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:',
        f'    return self.channel.{input_mode}_{output_mode}( # type: ignore',
        *call_args,
        f'    )({input_type}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout) # type: ignore',
    ]
| |
| |
def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Produce the client stub class `<Service>` that wraps a gRPC channel.

    Returns:
        Source lines of the class: the channel attribute, `__init__`, and one
        method per RPC in declaration order (ending with an empty line).
    """
    rendered = []
    for rpc in service.method:
        rpc_lines = generate_service_method(imports, file, service, rpc, sync)
        rendered.append('\n    '.join(rpc_lines))
    methods = '\n\n    '.join(rendered)

    channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel'
    header = [
        f'class {service.name}:',
        f'    channel: {channel_type}',
        '',
        f'    def __init__(self, channel: {channel_type}) -> None:',
        '        self.channel = channel',
        '',
    ]
    return header + f'    {methods}\n'.split('\n')
| |
| |
def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Produce a default servicer method that reports UNIMPLEMENTED.

    Streaming requests are typed as `Iterator`/`AsyncIterator`. Streaming
    responses are typed as (async) generators, and a trailing no-op `yield`
    keeps the method a generator function so that hint holds.

    Returns:
        Source lines of the method, indented relative to its `def`.
    """
    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    output_type_hint = output_type
    if method.server_streaming:
        if sync:
            output_type_hint = f'Generator[{output_type}, None, None]'
            add_import(imports, 'from typing import Generator')
        else:
            output_type_hint = f'AsyncGenerator[{output_type}, None]'
            add_import(imports, 'from typing import AsyncGenerator')

    request_type = input_type
    if method.client_streaming:
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        request_type = f'{iterator_type}[{input_type}]'

    prefix = '' if sync else 'async '
    lines = [
        f'{prefix}def {method.name}(self, request: {request_type}, context: grpc.ServicerContext) -> {output_type_hint}:',
        '    context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore',
        '    context.set_details("Method not implemented!") # type: ignore',
        '    raise NotImplementedError("Method not implemented!")',
    ]
    if method.server_streaming:
        lines.append(f'    yield {output_type}() # no-op: to make the linter happy')
    return lines
| |
| |
def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Produce the `<Service>Servicer` base class with UNIMPLEMENTED stubs.

    `file` is accepted for signature parity with the other generators and is
    not used. A service without RPCs yields a class whose body is `pass`.
    """
    bodies = ['\n    '.join(generate_servicer_method(imports, rpc, sync)) for rpc in service.method]
    methods = '\n\n    '.join(bodies) or 'pass'
    return f'class {service.name}Servicer:\n    {methods}\n'.split('\n')
| |
| |
def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]:
    """Produce one `'<Method>': grpc.<mode>_rpc_method_handler(...)` entry.

    Returns:
        Source lines of the dict entry, ending with an empty string; the
        `<module>_pb2` message classes provide the (de)serializers.
    """
    request_mode = 'stream' if method.client_streaming else 'unary'
    response_mode = 'stream' if method.server_streaming else 'unary'
    request_cls, _, _ = import_type(imports, method.input_type, None)
    response_cls, _, _ = import_type(imports, method.output_type, None)
    return [
        f"'{method.name}': grpc.{request_mode}_{response_mode}_rpc_method_handler( # type: ignore",
        f'    servicer.{method.name},',
        f'    request_deserializer={request_cls}.FromString, # type: ignore',
        f'    response_serializer={response_cls}.SerializeToString, # type: ignore',
        '),',
        '',
    ]
| |
| |
def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Produce `add_<Service>Servicer_to_server(servicer, server)`.

    The generated function builds one method handler per RPC and registers
    them on `server` as a generic handler for `<package>.<Service>`.

    Returns:
        Source lines of the generated function (ending with an empty line).
    """
    # Flatten every handler's lines, dropping empty ones (each handler's line
    # list ends with one), so all dict entries share the same indentation.
    handler_lines = [
        line
        for method in service.method
        for line in generate_rpc_method_handler(imports, method)
        if line
    ]
    method_handlers = '\n        '.join(handler_lines)
    server_type = 'grpc.Server' if sync else 'grpc.aio.Server'
    return (
        f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler( # type: ignore\n'
        f"        '{file.package}.{service.name}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,)) # type: ignore\n'
    ).split('\n')
| |
| |
| _HEADER = '''# Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Generated python gRPC interfaces.""" |
| ''' |
| |
| _UTILS_PY = f'''{_HEADER} |
| |
| import asyncio |
| import queue |
| import grpc |
| import sys |
| |
| from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar |
| |
| |
| _T_co = TypeVar('_T_co', covariant=True) |
| _T = TypeVar('_T') |
| |
| |
| class Stream(Iterator[_T_co], grpc.RpcContext): ... |
| |
| |
| class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ... |
| |
| |
| class Sender(Iterator[_T]): |
| if sys.version_info >= (3, 8): |
| _inner: queue.Queue[_T] |
| else: |
| _inner: queue.Queue |
| |
| def __init__(self) -> None: |
| self._inner = queue.Queue() |
| |
| def __iter__(self) -> Iterator[_T]: |
| return self |
| |
| def __next__(self) -> _T: |
| return self._inner.get() |
| |
| def send(self, item: _T) -> None: |
| self._inner.put(item) |
| |
| |
| class AioSender(AsyncIterator[_T]): |
| if sys.version_info >= (3, 8): |
| _inner: asyncio.Queue[_T] |
| else: |
| _inner: asyncio.Queue |
| |
| def __init__(self) -> None: |
| self._inner = asyncio.Queue() |
| |
| def __iter__(self) -> AsyncIterator[_T]: |
| return self |
| |
| async def __anext__(self) -> _T: |
| return await self._inner.get() |
| |
| async def send(self, item: _T) -> None: |
| await self._inner.put(item) |
| |
| def send_nowait(self, item: _T) -> None: |
| self._inner.put_nowait(item) |
| |
| |
| class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext): |
| _sender: Sender[_T] |
| _receiver: Stream[_T_co] |
| |
| def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None: |
| self._sender = sender |
| self._receiver = receiver |
| |
| def send(self, item: _T) -> None: |
| self._sender.send(item) |
| |
| def __iter__(self) -> Iterator[_T_co]: |
| return self._receiver.__iter__() |
| |
| def __next__(self) -> _T_co: |
| return self._receiver.__next__() |
| |
| def is_active(self) -> bool: |
| return self._receiver.is_active() # type: ignore |
| |
| def time_remaining(self) -> float: |
| return self._receiver.time_remaining() # type: ignore |
| |
| def cancel(self) -> None: |
| self._receiver.cancel() # type: ignore |
| |
| def add_callback(self, callback: Any) -> None: |
| self._receiver.add_callback(callback) # type: ignore |
| |
| |
| class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext): |
| _sender: AioSender[_T] |
| _receiver: AioStream[_T_co] |
| |
| def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None: |
| self._sender = sender |
| self._receiver = receiver |
| |
| def __aiter__(self) -> AsyncIterator[_T_co]: |
| return self._receiver.__aiter__() |
| |
| async def __anext__(self) -> _T_co: |
| return await self._receiver.__aiter__().__anext__() |
| |
| async def send(self, item: _T) -> None: |
| await self._sender.send(item) |
| |
| def send_nowait(self, item: _T) -> None: |
| self._sender.send_nowait(item) |
| |
| def is_active(self) -> bool: |
| return self._receiver.is_active() # type: ignore |
| |
| def time_remaining(self) -> float: |
| return self._receiver.time_remaining() # type: ignore |
| |
| def cancel(self) -> None: |
| self._receiver.cancel() # type: ignore |
| |
| def add_callback(self, callback: Any) -> None: |
| self._receiver.add_callback(callback) # type: ignore |
| ''' |
| |
| |
# Response files accumulated for every requested proto, and the `_utils.py`
# paths already emitted (one per output directory).
_FILES: List[CodeGeneratorResponse.File] = []
_UTILS_FILES: Set[str] = set()


for file_name in _REQUEST.file_to_generate:
    file: FileDescriptorProto = next(f for f in _REQUEST.proto_file if f.name == file_name)

    # `unwrap` is called by the oneof accessors that generate_message injects
    # into the same `_pb2.py`.
    _FILES.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point='module_scope',
        content='def unwrap(x):\n    assert x\n    return x\n',
    ))

    pyi_imports: List[str] = []
    grpc_imports: List[str] = ['import grpc']
    grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio']

    # Each generator returns a list of lines; flatten them in order (the
    # generators also record imports and queue insertions as side effects).
    enums = '\n'.join(
        line for enum in file.enum_type
        for line in generate_enum(pyi_imports, file, enum, _FILES))
    messages = '\n'.join(
        line for message in file.message_type
        for line in generate_message(pyi_imports, file, message, _FILES))

    services = '\n'.join(
        line for service in file.service
        for line in generate_service(grpc_imports, file, service))
    aio_services = '\n'.join(
        line for service in file.service
        for line in generate_service(grpc_aio_imports, file, service, False))

    servicers = '\n'.join(
        line for service in file.service
        for line in generate_servicer(grpc_imports, file, service))
    aio_servicers = '\n'.join(
        line for service in file.service
        for line in generate_servicer(grpc_aio_imports, file, service, False))

    add_servicer_methods = '\n'.join(
        line for service in file.service
        for line in generate_add_servicer_to_server_method(grpc_imports, file, service))
    aio_add_servicer_methods = '\n'.join(
        line for service in file.service
        for line in generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False))

    pyi_imports_str: str = '\n'.join(sorted(pyi_imports))
    grpc_imports_str: str = '\n'.join(sorted(grpc_imports))
    grpc_aio_imports_str: str = '\n'.join(sorted(grpc_aio_imports))

    # Replace only the last path component (proto names use '/' separators,
    # as import_type also assumes); a plain str.replace of the basename would
    # also rewrite any directory name containing it.
    directory, sep, _ = file_name.rpartition('/')
    utils_filename = f'{directory}{sep}_utils.py'
    if utils_filename not in _UTILS_FILES:
        _UTILS_FILES.add(utils_filename)
        _FILES.append(CodeGeneratorResponse.File(
            name=utils_filename,
            content=_UTILS_PY,
        ))

    _FILES.extend([
        CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.pyi'),
            content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc.py'),
            content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc_aio.py'),
            content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}'
        ),
    ])


sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString())