Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 46cf96b

Browse files
afrozenatorcopybara-github
authored andcommitted
Changes to be able to use and generate proto files externally.
- Service options : deadline and fail_fast don't seem to be supported externally by gRPC, remove them. - Script added to generate the proto and service files (generate_py_proto.sh) - The generated files are added to the source code in order for Travis to work. env_service_generated_pb2.py and env_service_generated_pb2_grpc.py PiperOrigin-RevId: 264512062
1 parent 212b599 commit 46cf96b

File tree

5 files changed

+1157
-15
lines changed

5 files changed

+1157
-15
lines changed

oss_scripts/generate_py_proto.sh

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/bin/bash
2+
3+
# This script use the protoc compiler to generate the python code of the
4+
# all of our proto files.
5+
6+
7+
# Function to prepend a pylint directive to skip the generated python file.
8+
function pylint_skip_file() {
9+
local file_name=$1
10+
printf "%s\n%s" "# pylint: skip-file" "$(cat ${file_name})" > ${file_name}
11+
}
12+
13+
14+
# Setup tmp directories
15+
TMP_DIR=$(mktemp -d)
16+
TMP_TF_DIR=${TMP_DIR}/tensorflow
17+
TMP_T2T_DIR="$PWD"
18+
19+
echo "Temporary directory created: "
20+
echo ${TMP_DIR}
21+
22+
23+
TMP_T2T_PROTO_DIR="${TMP_T2T_DIR}/tensor2tensor/envs"
24+
ENV_SERVICE_PROTO="${TMP_T2T_PROTO_DIR}/env_service.proto"
25+
if [ ! -f ${ENV_SERVICE_PROTO} ]; then
26+
echo "${ENV_SERVICE_PROTO} not found."
27+
echo "Please run this script from the appropriate root directory."
28+
fi
29+
30+
# Clone tensorflow repository.
31+
git clone https://github.com/tensorflow/tensorflow.git ${TMP_TF_DIR}
32+
33+
# Install gRPC tools.
34+
pip install grpcio-tools
35+
36+
# Invoke the grpc protoc compiler on env_service.proto
37+
python -m grpc_tools.protoc \
38+
--proto_path=${TMP_TF_DIR}/ \
39+
--proto_path=${TMP_T2T_DIR}/ \
40+
--python_out=${TMP_T2T_DIR}/ \
41+
--grpc_python_out=${TMP_T2T_DIR}/ \
42+
${ENV_SERVICE_PROTO}
43+
44+
# Add pylint ignore and name the file as generated.
45+
GENERATED_ENV_SERVICE_PY="${TMP_T2T_PROTO_DIR}/env_service_generated_pb2.py"
46+
GENERATED_ENV_SERVICE_GRPC_PY="${TMP_T2T_PROTO_DIR}/env_service_generated_pb2_grpc.py"
47+
mv ${TMP_T2T_PROTO_DIR}/env_service_pb2.py ${GENERATED_ENV_SERVICE_PY}
48+
mv ${TMP_T2T_PROTO_DIR}/env_service_pb2_grpc.py ${GENERATED_ENV_SERVICE_GRPC_PY}
49+
pylint_skip_file "${GENERATED_ENV_SERVICE_PY}"
50+
pylint_skip_file "${GENERATED_ENV_SERVICE_GRPC_PY}"
51+
52+
53+
LICENSING_TEXT=$(cat <<-END
54+
# coding=utf-8
55+
# Copyright 2019 The Tensor2Tensor Authors.
56+
#
57+
# Licensed under the Apache License, Version 2.0 (the "License");
58+
# you may not use this file except in compliance with the License.
59+
# You may obtain a copy of the License at
60+
#
61+
# http://www.apache.org/licenses/LICENSE-2.0
62+
#
63+
# Unless required by applicable law or agreed to in writing, software
64+
# distributed under the License is distributed on an "AS IS" BASIS,
65+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
66+
# See the License for the specific language governing permissions and
67+
# limitations under the License.
68+
END
69+
)
70+
71+
function add_licensing_text() {
72+
local file_name=$1
73+
printf "%s\n%s" "${LICENSING_TEXT}" "$(cat ${file_name})" > ${file_name}
74+
}
75+
76+
add_licensing_text "${GENERATED_ENV_SERVICE_PY}"
77+
add_licensing_text "${GENERATED_ENV_SERVICE_GRPC_PY}"
78+

tensor2tensor/envs/__init__.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,25 @@
1515

1616
"""Environments defined in T2T. Imports here force registration."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
18+
# Proto imports.
19+
20+
21+
# pylint: disable=g-import-not-at-top,g-statement-before-imports
22+
def _get_env_service():
23+
from tensor2tensor.envs import env_service_generated_pb2 as env_service_pb2_
24+
return env_service_pb2_
25+
26+
27+
def _get_env_service_grpc():
28+
from tensor2tensor.envs import env_service_generated_pb2_grpc as env_service_pb2_grpc_
29+
return env_service_pb2_grpc_
30+
# pylint: enable=g-import-not-at-top
31+
32+
33+
env_service_pb2 = _get_env_service() # pylint: disable=invalid-name
34+
env_service_pb2_grpc = _get_env_service_grpc() # pylint: disable=invalid-name
35+
del _get_env_service, _get_env_service_grpc
36+
# pylint: enable=g-statement-before-imports
2137

2238
from gym.envs.registration import register
2339

@@ -38,3 +54,4 @@ def register_env(env_class):
3854
# TODO(afrozm): Register TicTacToeEnv the same way.
3955
# register_env(tic_tac_toe_env.TicTacToeEnv)
4056
ClientEnv = register_env(client_env.ClientEnv) # pylint: disable=invalid-name
57+

tensor2tensor/envs/env_service.proto

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
syntax = "proto3";
22
option cc_enable_arenas = true;
33

4-
package third_party.py.tensor2tensor.trax.rlax.envs;
4+
package tensor2tensor.trax.rlax.envs;
55

6-
import "third_party/tensorflow/core/framework/tensor.proto";
7-
import "third_party/tensorflow/core/framework/tensor_shape.proto";
8-
import "third_party/tensorflow/core/framework/types.proto";
6+
import "tensorflow/core/framework/tensor.proto";
7+
import "tensorflow/core/framework/tensor_shape.proto";
8+
import "tensorflow/core/framework/types.proto";
99

1010
// We use tensorflow.TensorProto to represent numpy arrays.
1111

@@ -89,30 +89,22 @@ message EnvInfoResponse {
8989
service EnvService {
9090
// Reset
9191
rpc Reset(ResetRequest) returns (ResetResponse) {
92-
option fail_fast = true;
9392
}
9493

9594
// Step
9695
rpc Step(StepRequest) returns (StepResponse) {
97-
option fail_fast = true;
9896
}
9997

10098
// Close
10199
rpc Close(CloseRequest) returns (CloseResponse) {
102-
option fail_fast = true;
103-
option deadline = 10;
104100
}
105101

106102
// Render
107103
rpc Render(RenderRequest) returns (RenderResponse) {
108-
option fail_fast = true;
109-
option deadline = 10;
110104
}
111105

112106
// Observation and Action Space.
113107
rpc GetEnvInfo(EnvInfoRequest) returns (EnvInfoResponse) {
114-
option fail_fast = true;
115-
option deadline = 10;
116108
}
117109
}
118110

0 commit comments

Comments
 (0)