Fix samples so that they catch AccessTokenRefreshError.
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index 24fdf19..81f7a76 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -43,6 +43,7 @@
from apiclient.discovery import build
from oauth2client.file import Storage
+from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import OAuth2WebServerFlow
from oauth2client.tools import run
@@ -97,34 +98,39 @@
service = build("prediction", "v1.2", http=http)
- # Name of Google Storage bucket/object that contains the training data
- OBJECT_NAME = "apiclient-prediction-sample/prediction_models/languages"
+ try:
- # Start training on a data set
- train = service.training()
- start = train.insert(data=OBJECT_NAME, body={}).execute()
+ # Name of Google Storage bucket/object that contains the training data
+ OBJECT_NAME = "apiclient-prediction-sample/prediction_models/languages"
- print 'Started training'
- pprint.pprint(start)
+ # Start training on a data set
+ train = service.training()
+ start = train.insert(data=OBJECT_NAME, body={}).execute()
- import time
- # Wait for the training to complete
- while True:
- status = train.get(data=OBJECT_NAME).execute()
- pprint.pprint(status)
- if 'RUNNING' != status['trainingStatus']:
- break
- print 'Waiting for training to complete.'
- time.sleep(10)
- print 'Training is complete'
+ print 'Started training'
+ pprint.pprint(start)
- # Now make a prediction using that training
- body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = service.predict(body=body, data=OBJECT_NAME).execute()
- print 'The prediction is:'
- pprint.pprint(prediction)
+ import time
+ # Wait for the training to complete
+ while True:
+ status = train.get(data=OBJECT_NAME).execute()
+ pprint.pprint(status)
+ if 'RUNNING' != status['trainingStatus']:
+ break
+ print 'Waiting for training to complete.'
+ time.sleep(10)
+ print 'Training is complete'
+
+ # Now make a prediction using that training
+ body = {'input': {'csvInstance': ["mucho bueno"]}}
+ prediction = service.predict(body=body, data=OBJECT_NAME).execute()
+ print 'The prediction is:'
+ pprint.pprint(prediction)
+ except AccessTokenRefreshError:  # token refresh failed; tell the user to re-run and re-authorize
+ print ("The credentials have been revoked or expired, please re-run "
+ "the application to re-authorize")
if __name__ == '__main__':
main(sys.argv)