imported patch issue4815077_3002.diff
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index 9daf31e..3ffa974 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -36,6 +36,7 @@
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
+import apiclient.errors
import gflags
import httplib2
import logging
@@ -102,7 +103,7 @@
http = httplib2.Http()
http = credentials.authorize(http)
- service = build("prediction", "v1.2", http=http)
+ service = build("prediction", "v1.3", http=http)
try:
@@ -117,17 +118,22 @@
import time
# Wait for the training to complete
while True:
- status = train.get(data=FLAGS.object_name).execute()
- pprint.pprint(status)
- if 'RUNNING' != status['trainingStatus']:
+ try:
+ # Check whether the training job has completed; if it has not, get() raises an HttpError.
+ status = train.get(data=FLAGS.object_name).execute()
+ # Job has completed.
+ pprint.pprint(status)
break
- print 'Waiting for training to complete.'
- time.sleep(10)
+ except apiclient.errors.HttpError as error:
+ # Assume the job is still training. NOTE: any other HttpError (bad name, auth failure) also lands here and loops forever.
+ print 'Waiting for training to complete.'
+ time.sleep(10)
+
print 'Training is complete'
# Now make a prediction using that training
body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = service.predict(body=body, data=FLAGS.object_name).execute()
+ prediction = train.predict(body=body, data=FLAGS.object_name).execute()
print 'The prediction is:'
pprint.pprint(prediction)